import os
import asyncio
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AIMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
# NOTE: you must use langchain-core >= 0.3 with Pydantic v2
from pydantic import BaseModel

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from copilotkit import CopilotKitState
from copilotkit.langchain import copilotkit_customize_config

# Import our facade helpers
from .llm_mcp_facade import create_deepseek_mcp_facade
from .domain_mcp_loader import get_mcp_servers_config_by_domain, get_tools_by_domain
from .db_config import get_db_config

# Temporary imports, kept for compatibility with existing code
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_deepseek import ChatDeepSeek

class State(CopilotKitState):
    """Graph state: CopilotKit base state plus a human-escalation flag."""

    # Set by `chatbot` when the model calls the RequestAssistance tool;
    # cleared by `human_node` once the interrupt has been handled.
    ask_human: bool


# NOTE: this model is bound as a tool in `chatbot`; its docstring below is the
# tool description the LLM sees at runtime, so it is deliberately left in the
# original Chinese. Translation: "If you still need the user to supply more
# information to answer their question, you can call this tool."
class RequestAssistance(BaseModel):
    """如果你回答用户问题是还需要用户补充信息，可以调用这个工具.
    """

    # Free-text description of what extra information is needed from the user.
    request: str


async def get_tools(domain: str = "law"):
    """Collect the agent's tools: Tavily web search plus domain MCP tools.

    Args:
        domain: MCP domain whose tools should be loaded. Defaults to "law",
            preserving the previously hard-coded behavior.

    Returns:
        list: The Tavily search tool, followed by any MCP tools that loaded.
        MCP loading is best-effort: on failure the base tool list is returned.
    """
    # Tavily web-search tool is always available.
    tools = [TavilySearchResults(max_results=2)]

    try:
        mcp_tools = await get_tools_by_domain(domain)
    except Exception as e:
        # Best-effort: MCP servers may be unreachable; keep the base tools.
        print(f"加载 MCP 工具失败: {e}")
        print("继续使用基础工具...")
    else:
        tools.extend(mcp_tools)
        print(f"成功加载 {len(mcp_tools)} 个 MCP 工具")

    return tools

# Tool cache placeholder (intended to be populated at runtime).
# NOTE(review): `_tools` is never read or written anywhere in this module view —
# possibly dead code or used by another file; confirm before removing.
_tools = None



# LLM configuration: DeepSeek chat model, credentials taken from environment.
api_base = os.getenv("DS_BASE_URL")
llm_kwargs = {
    "model": "deepseek-chat",
    "api_key": os.getenv("DS_API_KEY"),
}
# Truthiness check on purpose: an empty DS_BASE_URL is treated as "unset".
if api_base:
    llm_kwargs["api_base"] = api_base

llm = ChatDeepSeek(**llm_kwargs)


async def chatbot(state: State, config: RunnableConfig):
    """LLM node: call the tool-bound model and detect human-escalation requests.

    Args:
        state: Current graph state (messages plus the `ask_human` flag).
        config: Runnable config, customized so CopilotKit emits tool calls.

    Returns:
        dict: Partial state update with the model response appended and
        `ask_human` set when the model asked for human assistance.
    """
    # Configure CopilotKit to stream tool-call information to the frontend.
    config = copilotkit_customize_config(config, emit_tool_calls=True)

    # Fetch tools on every turn (MCP availability may change) and bind them,
    # plus the RequestAssistance escalation pseudo-tool.
    tools = await get_tools()
    llm_with_tools = llm.bind_tools(tools + [RequestAssistance])

    # Bug fix: use the async API — the previous synchronous `.invoke` blocked
    # the event loop inside this async graph node.
    response = await llm_with_tools.ainvoke(state["messages"], config=config)

    # Escalate only when the *first* tool call is RequestAssistance
    # (unchanged from the original routing behavior).
    ask_human = bool(
        response.tool_calls
        and response.tool_calls[0]["name"] == RequestAssistance.__name__
    )
    return {"messages": [response], "ask_human": ask_human}


async def create_graph():
    """Build the base StateGraph with the chatbot and tool-execution nodes.

    Returns:
        tuple: (graph builder, list of tools bound into the tool node).
    """
    available_tools = await get_tools()

    builder = StateGraph(State)
    builder.add_node("chatbot", chatbot)
    builder.add_node("tools", ToolNode(tools=available_tools))

    return builder, available_tools

# Module-level cache for the compiled graph (populated by get_compiled_graph).
_graph = None
# NOTE(review): `_graph_domain` is never used in this module view — possibly
# reserved for per-domain graph caching; confirm before removing.
_graph_domain = None


def create_response(response: str, ai_message: AIMessage):
    """Wrap *response* as a ToolMessage answering the first tool call
    carried by *ai_message*."""
    call_id = ai_message.tool_calls[0]["id"]
    return ToolMessage(content=response, tool_call_id=call_id)


def human_node(state: State):
    """Node run after the human interrupt: guarantee the pending tool call is
    answered, then clear the escalation flag."""
    last_message = state["messages"][-1]
    appended = []
    # Typically the user updated state during the interrupt. If the newest
    # message is still not a ToolMessage, insert a placeholder response so
    # the LLM can continue the conversation.
    if not isinstance(last_message, ToolMessage):
        appended.append(create_response("No response from human.", last_message))
    return {
        "messages": appended,  # append (possibly empty) new messages
        "ask_human": False,    # unset the escalation flag
    }


def select_next_node(state: State):
    if state["ask_human"]:
        return "human"
    # Otherwise, we can route as before
    return tools_condition(state)


async def get_compiled_graph():
    """Lazily wire and compile the agent graph, caching it module-wide.

    Returns:
        The compiled LangGraph, interrupting before the "human" node so a
        person can respond to RequestAssistance calls.
    """
    global _graph
    if _graph is not None:
        return _graph

    builder, _ = await create_graph()
    builder.add_node("human", human_node)

    # chatbot routes to human escalation, tool execution, or termination.
    builder.add_conditional_edges(
        "chatbot",
        select_next_node,
        {"human": "human", "tools": "tools", "__end__": "__end__"},
    )
    builder.add_edge("tools", "chatbot")
    builder.add_edge("human", "chatbot")
    builder.set_entry_point("chatbot")

    _graph = builder.compile(
        checkpointer=MemorySaver(),
        interrupt_before=["human"],  # pause for the human before this node runs
    )
    return _graph

# Async initialization entry point.
async def initialize_agent():
    """Initialize the agent and return the compiled graph."""
    return await get_compiled_graph()

# Backward-compatible synchronous access to the graph.
graph = None  # set on first call to initialize_agent_sync()

# Use this helper when synchronous initialization is required.
def initialize_agent_sync():
    """Synchronously initialize the agent and cache it in `graph`.

    NOTE(review): relies on asyncio.run, so it must not be called from within
    an already-running event loop.
    """
    global graph
    if graph is None:
        graph = asyncio.run(initialize_agent())
    return graph